Deep Q-Networks (DQN)
Deep Q-Networks (DQN) represent a breakthrough in reinforcement learning that combines Q-learning with deep neural networks to handle high-dimensional state spaces. Introduced by DeepMind in papers published in 2013 and 2015, DQN learned to play Atari 2600 games directly from pixel input, reaching human-level or better performance on many of them.
The Limitation of Tabular Q-Learning
Traditional Q-learning maintains a table of Q-values for every state-action pair. This approach becomes impractical when:
- The state space is high-dimensional (e.g., images with thousands of pixels)
- The state space is continuous (infinite possible states)
- The problem requires generalization between similar states
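A rough back-of-the-envelope calculation (illustrative, not from the original papers) shows why a table is hopeless for pixel input: even a single small grayscale frame admits far more distinct states than could ever be enumerated.
import math

# Illustrative: number of distinct states for a single 84x84 grayscale
# frame with 256 intensity levels per pixel.
pixels = 84 * 84
digits = int(pixels * math.log10(256))
print(f"Distinct frames: roughly 10^{digits}")  # about 10^16992
No table with one entry per state could ever be filled, which is why a function approximator that generalizes across similar states is needed.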
Core Concept of DQN
DQN replaces the Q-table with a deep neural network that approximates the Q-function:
- Input: State representation (e.g., raw pixels or features)
- Output: Q-values for each possible action
- Training: Minimize the difference between predicted Q-values and target Q-values
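A minimal sketch of this idea for a low-dimensional state (the 4-dimensional input, hidden width of 128, and 2 actions are illustrative assumptions, e.g., CartPole-style features rather than pixels):
import torch
import torch.nn as nn

# Q-network as a function approximator: state features in, one Q-value per action out.
q_net = nn.Sequential(
    nn.Linear(4, 128),   # 4-dimensional state vector (illustrative)
    nn.ReLU(),
    nn.Linear(128, 2),   # one output per action (illustrative)
)

state = torch.randn(1, 4)        # a batch containing one state
q_values = q_net(state)          # shape (1, 2): Q(s, a) for each action
action = q_values.argmax(dim=1)  # greedy action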
Key Innovations in DQN
1. Experience Replay
- Store experience tuples (s, a, r, s') in a replay buffer
- Sample random mini-batches from this buffer for training
- Benefits:
- Breaks correlations between consecutive experiences
- Uses each experience for multiple updates
- Better data efficiency by reusing transitions
2. Target Network
- Maintain two networks:
- Q-network: Updated at each iteration
- Target Q-network: Periodically updated copy of Q-network
- Use target network for calculating target Q-values
- Reduces oscillations and improves stability by making targets more stationary
3. Convolutional Architecture
- For visual inputs, use convolutional neural networks
- Automatically learns useful features from raw pixels
- Captures spatial relationships and patterns
The DQN Algorithm
1. Initialize:
   - Q-network with random weights θ
   - Target network with weights θ⁻ = θ
   - Replay buffer D with capacity N
2. For each episode:
   - Initialize state s₁
   - For each step t:
     - With probability ε, select a random action aₜ (exploration); otherwise select aₜ = argmaxₐ Q(sₜ, a; θ) (exploitation)
     - Execute action aₜ, observe reward rₜ and next state sₜ₊₁
     - Store transition (sₜ, aₜ, rₜ, sₜ₊₁) in replay buffer D
     - Sample a mini-batch of transitions (sⱼ, aⱼ, rⱼ, sⱼ₊₁) from D
     - Set the target:
       - If sⱼ₊₁ is terminal: yⱼ = rⱼ
       - Otherwise: yⱼ = rⱼ + γ maxₐ′ Q(sⱼ₊₁, a′; θ⁻)
     - Update the Q-network: perform a gradient descent step on (yⱼ − Q(sⱼ, aⱼ; θ))²
     - Every C steps, update the target network: θ⁻ = θ
Mathematical Formulation
The loss function for DQN is:
L(θ) = E_(s,a,r,s′)~D [ (r + γ max_a′ Q(s′, a′; θ⁻) − Q(s, a; θ))² ]
Where:
- θ are the parameters of the Q-network
- θ⁻ are the parameters of the target network
- D is the replay buffer from which transitions (s, a, r, s′) are sampled uniformly
- γ is the discount factor
Implementation Example
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
class DQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(DQN, self).__init__()
        # Convolutional feature extractor (architecture from the Nature DQN paper)
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        conv_out_size = self._get_conv_output(input_shape)
        # Fully connected head that outputs one Q-value per action
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def _get_conv_output(self, shape):
        # Pass a dummy tensor through the conv stack to infer the flattened size
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)
def train_dqn(env, num_episodes, gamma=0.99, epsilon_start=1.0, epsilon_final=0.01,
              epsilon_decay=10000, batch_size=32, target_update=1000):
    # Initialize networks, optimizer, and replay buffer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    q_network = DQN(env.observation_space.shape, env.action_space.n).to(device)
    target_network = DQN(env.observation_space.shape, env.action_space.n).to(device)
    target_network.load_state_dict(q_network.state_dict())
    optimizer = optim.Adam(q_network.parameters(), lr=0.0001)
    replay_buffer = ReplayBuffer(100000)

    steps_done = 0
    episode_rewards = []

    for episode in range(num_episodes):
        state = env.reset()  # classic Gym API; in Gymnasium, reset() returns (obs, info)
        episode_reward = 0
        done = False

        while not done:
            # Epsilon-greedy action selection with exponential decay
            epsilon = epsilon_final + (epsilon_start - epsilon_final) * \
                np.exp(-1. * steps_done / epsilon_decay)
            steps_done += 1

            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(np.array(state)).unsqueeze(0).to(device)
                    q_values = q_network(state_tensor)
                    action = q_values.max(1)[1].item()

            # Execute action and store transition (classic Gym 4-tuple step API)
            next_state, reward, done, _ = env.step(action)
            replay_buffer.push(state, action, reward, next_state, done)
            episode_reward += reward
            state = next_state

            # Train once enough samples are in the replay buffer
            if len(replay_buffer) > batch_size:
                # Sample a batch; stack as numpy arrays first for fast tensor creation
                batch = replay_buffer.sample(batch_size)
                states, actions, rewards, next_states, dones = zip(*batch)
                state_batch = torch.FloatTensor(np.array(states)).to(device)
                action_batch = torch.LongTensor(actions).unsqueeze(1).to(device)
                reward_batch = torch.FloatTensor(rewards).to(device)
                next_state_batch = torch.FloatTensor(np.array(next_states)).to(device)
                done_batch = torch.FloatTensor(dones).to(device)

                # Compute current Q values for the actions actually taken
                current_q_values = q_network(state_batch).gather(1, action_batch)

                # Compute target Q values using the frozen target network
                with torch.no_grad():
                    max_next_q_values = target_network(next_state_batch).max(1)[0]
                    target_q_values = reward_batch + gamma * max_next_q_values * (1 - done_batch)

                # Compute Huber loss and update the Q-network
                loss = nn.functional.smooth_l1_loss(current_q_values, target_q_values.unsqueeze(1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Periodically sync the target network with the Q-network
            if steps_done % target_update == 0:
                target_network.load_state_dict(q_network.state_dict())

        episode_rewards.append(episode_reward)

        # Log progress
        if episode % 10 == 0:
            print(f"Episode {episode}, Avg Reward: {np.mean(episode_rewards[-10:]):.2f}, Epsilon: {epsilon:.2f}")

    return q_network
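A hypothetical usage sketch (the environment name and episode count are illustrative; the convolutional network above expects channel-first image observations, so a standard Atari preprocessing wrapper stack is assumed):
import gym

# Assumes an env whose observations are channel-first image arrays,
# e.g., an Atari env wrapped with the usual grayscale/resize/frame-stack
# preprocessing. Name and num_episodes are illustrative.
env = gym.make("PongNoFrameskip-v4")
trained_q = train_dqn(env, num_episodes=500)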
Extensions and Improvements to DQN
Double DQN (DDQN)
- Addresses overestimation bias in Q-learning
- Uses online network to select actions but target network to evaluate them
- Target: yⱼ = rⱼ + γ Q(sⱼ₊₁, argmax Q(sⱼ₊₁, a; θ); θ⁻)
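In code, only the target computation in the training loop changes; a sketch reusing the tensor names from train_dqn above:
# Double DQN target: the online network selects the next action,
# the target network evaluates it.
with torch.no_grad():
    next_actions = q_network(next_state_batch).argmax(1, keepdim=True)
    next_q = target_network(next_state_batch).gather(1, next_actions).squeeze(1)
    target_q_values = reward_batch + gamma * next_q * (1 - done_batch)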
Dueling DQN
- Separates value and advantage functions
- Architecture has two streams that share early layers:
- One stream estimates state value V(s)
- Another estimates advantage A(s,a)
- Combines them: Q(s,a) = V(s) + (A(s,a) − meanₐ′ A(s,a′))
- Better at identifying value of specific actions
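A sketch of the two-stream head, which could replace the fully connected head in the DQN class above (layer sizes are illustrative):
import torch.nn as nn

class DuelingHead(nn.Module):
    def __init__(self, conv_out_size, num_actions):
        super().__init__()
        # One stream estimates the state value V(s)
        self.value = nn.Sequential(
            nn.Linear(conv_out_size, 512), nn.ReLU(), nn.Linear(512, 1)
        )
        # The other estimates the advantage A(s, a) of each action
        self.advantage = nn.Sequential(
            nn.Linear(conv_out_size, 512), nn.ReLU(), nn.Linear(512, num_actions)
        )

    def forward(self, features):
        v = self.value(features)                    # shape (batch, 1)
        a = self.advantage(features)                # shape (batch, num_actions)
        return v + a - a.mean(dim=1, keepdim=True)  # Q(s, a)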
Prioritized Experience Replay
- Samples important transitions more frequently
- Transitions with higher TD error have higher priority
- Corrects sampling bias with importance sampling weights
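A minimal proportional-prioritization sketch (the paper's version uses a sum-tree for efficient sampling; the alpha, beta, and eps values here are common defaults, not mandated):
import numpy as np

class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha          # how strongly TD error skews sampling
        self.buffer = []
        self.priorities = []
        self.pos = 0

    def push(self, transition):
        # New transitions get max priority so they are sampled at least once
        max_prio = max(self.priorities, default=1.0)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
            self.priorities.append(max_prio)
        else:
            self.buffer[self.pos] = transition
            self.priorities[self.pos] = max_prio
            self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        probs = np.array(self.priorities) ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        # Importance-sampling weights correct the non-uniform sampling bias
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()
        return [self.buffer[i] for i in indices], indices, weights

    def update_priorities(self, indices, td_errors, eps=1e-5):
        for i, err in zip(indices, td_errors):
            self.priorities[i] = abs(err) + eps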
Rainbow DQN
- Combines multiple improvements to DQN:
- Double Q-learning
- Prioritized experience replay
- Dueling networks
- Multi-step learning
- Distributional RL
- Noisy networks for exploration
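As one concrete example of these components, a sketch of the multi-step (n-step) return that replaces the one-step target (function name and defaults are illustrative):
def n_step_return(rewards, bootstrap_value, gamma=0.99, n=3):
    # rewards: up to n rewards observed after state s_t
    # bootstrap_value: max_a Q(s_{t+n}, a; θ⁻) from the target network,
    #                  or 0 if the episode ended within the n steps
    g = 0.0
    for k, r in enumerate(rewards[:n]):
        g += (gamma ** k) * r
    return g + (gamma ** min(len(rewards), n)) * bootstrap_value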
Applications of DQN
- Game Playing: Atari games, board games, more complex games
- Robotics: Learning control policies from visual input
- Autonomous Systems: Navigation, obstacle avoidance
- Resource Management: Data center optimization, traffic control
- Healthcare: Treatment optimization
- Finance: Trading strategies
Limitations of DQN
- Sample Inefficiency: Requires many interactions with the environment
- Hyperparameter Sensitivity: Performance depends on hyperparameter tuning
- Discrete Action Spaces: Original DQN only works with discrete actions
- Convergence Issues: May be unstable in some environments
- Overestimation Bias: Tends to overestimate Q-values
Recent Developments
- Distributed DQN: Parallel actors for more efficient data collection
- Episodic Memory: Incorporating episodic memory for faster learning
- Self-supervised Learning: Using self-supervised objectives to improve representations
- Curiosity-driven Exploration: Intrinsic motivation for better exploration
- Meta-learning: Learning to learn across multiple tasks
Deep Q-Networks represent a fundamental breakthrough in reinforcement learning, allowing RL algorithms to scale to complex, high-dimensional problems. Their success has spurred tremendous research activity and real-world applications in deep reinforcement learning.